import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import torch
import os
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import cv2
import matplotlib.pyplot as plt
# Notebook stderr captured during import (harmless, kept as a comment so the file parses):
#   C:\Users\lisas\anaconda3\envs\xai_model_explanation\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
# --- Load per-epoch test predictions and ground-truth labels ------------------
# predictions.csv holds one row per test example with columns '1'..'50': the
# class id predicted for that example after each of the 50 training epochs.
predictions = pd.read_csv('predictions.csv')
labels = torch.load(os.path.join('data', 'y_test.pt'))
predictions['label'] = pd.Series(labels)
# Fraction of the 50 per-epoch predictions that match the label — vectorized
# instead of the original per-row Python loop over scalar .iloc lookups.
# eq(..., axis=0) broadcasts the label column across the 50 epoch columns;
# assumes the label column is numeric like the CSV columns — TODO confirm.
predictions['% correct'] = predictions.iloc[:, :50].eq(predictions['label'], axis=0).mean(axis=1)
# Original row position, preserved as a plain column (the frame is re-sorted later).
predictions['index'] = pd.Series(list(range(len(predictions))))
# --- Per-example training-dynamics statistics ---------------------------------
# For every test example:
#   frequency         - fraction of the 49 consecutive-epoch pairs where the
#                       predicted class changed
#   variability       - number of distinct classes predicted over the 50 epochs,
#                       divided by the 6 total classes
#   misclassification - fraction of the 50 per-epoch predictions differing from
#                       the true label
# Vectorized with numpy instead of the original per-row loop, which performed
# two scalar column lookups (predictions[str(j)].iloc[i]) per epoch per example.
epoch_preds = predictions.iloc[:, :50].to_numpy()   # columns '1'..'50'
true_labels = predictions['label'].to_numpy()
frequencies = (epoch_preds[:, 1:] != epoch_preds[:, :-1]).sum(axis=1) / 49
# nunique(axis=1) equals len(set(row)) provided the predictions contain no
# NaNs — TODO confirm (class ids read from CSV should all be integers).
variabilities = predictions.iloc[:, :50].nunique(axis=1) / 6
scores = (epoch_preds != true_labels[:, None]).sum(axis=1) / 50
predictions['frequency'] = pd.Series(frequencies)
predictions['variability'] = pd.Series(variabilities)
predictions['misclassification'] = pd.Series(scores)
# sort predictions
# Reorder rows so that, within each of the 6 class labels, examples are sorted
# by '% correct' descending; then reverse the whole frame.
start = 0
for label in range(6):
    # All rows of this class, best-predicted first.
    sub_frame = predictions.loc[predictions['label'] == label].sort_values('% correct', ascending=False)
    # NOTE(review): assigning a DataFrame into an .iloc slice aligns on column
    # names while taking rows positionally from sub_frame — confirm this
    # write-back matches the intended in-place regrouping.
    predictions.iloc[start:start+len(sub_frame),:] = sub_frame
    start += len(sub_frame)
# Reverse row order (label 5 group ends up first, as seen in the head() output).
predictions = predictions[::-1]
predictions.head()
# Output of `predictions.head()` (5 rows x 56 columns), kept as a comment:
# | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 47 | 48 | 49 | 50 | label | % correct | index | frequency | variability | misclassification |
# |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
# | 2992 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 4 | 4 | 4 | 4 | 5 | 0.0 | 2831 | 0.326531 | 0.500000 | 1.0 |
# | 2991 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 5 | 0.0 | 2773 | 0.122449 | 0.333333 | 1.0 |
# | 2990 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 5 | 0.0 | 2493 | 0.000000 | 0.166667 | 1.0 |
# | 2989 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 1 | 1 | 1 | 1 | 5 | 0.0 | 2944 | 0.306122 | 0.333333 | 1.0 |
# | 2988 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 5 | 0.0 | 2636 | 0.000000 | 0.166667 | 1.0 |
# 5 rows × 56 columns
# --- Per-epoch test accuracy and stored training metrics ----------------------
# test_acc[i]: percentage of test examples whose epoch-(i+1) prediction matches
# the true label.
test_acc = [
    np.sum(np.array(predictions[str(i + 1)]) == np.array(predictions['label'])) / len(predictions) * 100
    for i in range(50)
]
# Use os.path.join instead of hard-coded Windows '\\' separators so the script
# also runs on POSIX systems (consistent with the data load above).
val_loss = torch.load(os.path.join('training', 'metrics', 'validation_loss.pt'))
train_loss = torch.load(os.path.join('training', 'metrics', 'train_loss.pt'))
train_acc = torch.load(os.path.join('training', 'metrics', 'train_accuracy.pt'))
def discrete_colorscale(bvals, colors):
    """
    Build a plotly discrete colorscale from interval boundaries.

    bvals  - list of values bounding the intervals of interest; interval k is
             [bvals[k], bvals[k+1]] after sorting
    colors - list of rgb/hex color codes, one per interval

    Returns a list of [normalized_position, color] pairs in which each color
    appears twice (at both ends of its interval), producing hard color steps.

    Raises ValueError when the boundary count is not len(colors) + 1.
    """
    if len(bvals) != len(colors)+1:
        raise ValueError('len(boundary values) should be equal to len(colors)+1')
    bounds = sorted(bvals)
    # Map each boundary onto [0, 1].
    span = bounds[-1] - bounds[0]
    normed = [(b - bounds[0]) / span for b in bounds]
    # Repeat every color at the start and end of its interval.
    return [
        [pos, col]
        for k, col in enumerate(colors)
        for pos in (normed[k], normed[k + 1])
    ]
# Boundaries placed around the integer class ids 0..5, one solid color band each.
bvals = [-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
colors = ['red', 'green', 'lightblue', 'yellow', 'darkblue', 'grey']
dcolorsc = discrete_colorscale(bvals, colors)
# Main heatmap: predicted class per example over the first 46 epochs (columns '1'..'46').
heatmap = go.Heatmap(x = list(range(1,len(predictions)+1)), z=predictions.iloc[:,:46], colorscale = dcolorsc, showscale = False)
# One-column strip of the ground-truth labels, same colorscale as the heatmap.
heatmap_scale = go.Heatmap(x = list(range(1,len(predictions)+1)), z=np.array(predictions['label']).reshape((-1,1)), colorscale = dcolorsc, showscale = False)
# Horizontal bar traces: one bar per example, aligned with the heatmap rows.
freq = go.Bar(x=predictions['frequency'], orientation='h')
var = go.Bar(x=predictions['variability'], orientation='h')
miss = go.Bar(x=predictions['misclassification'], orientation='h')
# Loss / accuracy curves restricted to the first 46 epochs.
val_loss_trace = go.Scatter({'x': list(range(1,47)),'y': val_loss[:46]}, line = dict(color='orange'))
train_loss_trace = go.Scatter({'x': list(range(1,47)),'y': train_loss[:46]}, line = dict(color='purple'))
test_acc_trace = go.Scatter({'x': list(range(1,47)),'y': test_acc[:46]}, line = dict(color='darkorange'))
train_acc_trace = go.Scatter({'x': list(range(1,47)),'y': train_acc[:46]}, line = dict(color='purple'))
# Confusion matrix of true label (rows) vs. prediction at epoch 46 (columns).
# NOTE(review): column '46' is used rather than the final epoch '50' —
# presumably to match the 46-epoch curves above; confirm.
confusion_matrix = np.array(pd.crosstab(predictions['label'], predictions['46']))
# NOTE(review): axis=0 normalizes each *predicted*-class column; per-true-label
# rates would need axis=1 — confirm which normalization is intended.
cm = go.Heatmap({'z': confusion_matrix/np.sum(confusion_matrix, axis=0),
                 'x':['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'],
                 'y':['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'],
                 'text': confusion_matrix/np.sum(confusion_matrix, axis=0),
                 'colorscale': 'greys'}, showscale = False)
# --- Assemble the dashboard ---------------------------------------------------
# Grid layout (3 rows x 5 cols):
#   row 1: label strip | prediction heatmap | frequency | variability | misclassification
#   row 2:             | loss curves        | confusion matrix (spans 2 rows x 3 cols)
#   row 3:             | accuracy curves    |
fig = make_subplots(
    rows=3, cols=5,
    specs=[[{}, {}, {}, {}, {}],
           [None, {}, {"rowspan": 2, "colspan": 3}, None, None],
           [None, {}, None, None, None]],
    row_heights=[0.7, 0.15, 0.15], column_widths=[0.02, 0.7, 0.1, 0.1, 0.1],
    horizontal_spacing = 0.01, vertical_spacing = 0.06,
    subplot_titles=('Labels', 'Confusion Evolution', 'Frequency', 'Variability', 'Misclassification Rate', 'Loss', 'Confusion Matrix', 'Accuracy'),
    print_grid=False)
fig.add_trace(heatmap_scale, row=1, col=1)  # ground-truth label strip
fig.add_trace(heatmap, row=1, col=2)        # per-epoch predictions
fig.add_trace(freq, row=1, col=3)
fig.add_trace(var, row=1, col=4)
fig.add_trace(miss, row=1, col=5)
fig.add_trace(val_loss_trace, row=2, col=2)
fig.add_trace(train_loss_trace, row=2, col=2)
fig.add_trace(test_acc_trace, row=3, col=2)
fig.add_trace(train_acc_trace, row=3, col=2)
fig.add_trace(cm, row = 2, col = 3)
# Axis styling: xaxisN/yaxisN correspond to the N-th subplot in creation order.
# The three bar charts (axes 3-5) share a fixed [0, 1] range for comparability.
fig.update_layout(plot_bgcolor='white', height=1500, width=1500,
                  xaxis={'visible': False}, yaxis={'visible': False},
                  xaxis2={'visible': False, 'side': 'top'}, yaxis2={'visible': False},
                  xaxis3={'visible': True, 'range': [0, 1]}, yaxis3={'visible': False},
                  xaxis4={'visible': True, 'range': [0, 1]}, yaxis4={'visible': False},
                  xaxis5={'visible': True, 'range': [0, 1]}, yaxis5={'visible': False},
                  xaxis6={'visible': False}, yaxis6={'visible': True, 'range': [0, 3], 'title': 'Loss'},
                  xaxis7={'visible': True}, yaxis7={'visible': True, 'side': 'right'},
                  xaxis8={'visible': True, 'title': 'epochs'}, yaxis8={'visible': True, 'range': [0, 100], 'title': 'Accuracy(%)'},
                  showlegend=False)
fig.update_coloraxes(showscale = False)
# Class-name labels down the left edge of the label strip; the y positions are
# hand-tuned paper coordinates matched to the sorted label groups. Curve labels
# below are likewise manually positioned next to their traces.
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.43,text="buildings", textangle = 270, showarrow = False, font=dict(size=16))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.52, text="forest", textangle = 270, showarrow = False, font=dict(size=16, color="white"))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.627,text="glacier", textangle = 270, showarrow = False, font=dict(size=16))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.766,text="mountain", textangle = 270, showarrow = False, font=dict(size=16))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.855,text="sea", textangle = 270, showarrow = False, font=dict(size=16, color="white"))
fig.add_annotation(xref='paper', yref='paper',x=0.001, y=0.967,text="street", textangle = 270, showarrow = False, font=dict(size=16, color="white"))
fig.add_annotation(xref='paper', yref='paper',x=0.54, y=0.3,text="validation loss", showarrow = False, font=dict(size=16, color="orange"))
fig.add_annotation(xref='paper', yref='paper',x=0.025, y=0.2,text="training loss", showarrow = False, font=dict(size=16, color="purple"))
fig.add_annotation(xref='paper', yref='paper',x=0.04, y=0.025,text="test accuracy", showarrow = False, font=dict(size=16, color="darkorange"))
fig.add_annotation(xref='paper', yref='paper',x=0.1, y=0.12,text="training accuracy", showarrow = False, font=dict(size=16, color="purple"))
fig.show()